{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 基于 PyTorch 的联邦学习自定义 loss function 教程\n",
    "## 引言\n",
    "### 背景\n",
    "在联邦学习中，尤其是监督学习中，我们常常需要使用损失函数监督模型的训练；通过之前的[入门教程](https://www.secretflow.org.cn/docs/secretflow/latest/zh-Hans/tutorial/Federated_Learning_with_Pytorch_backend), 我们已经展示如何通过 `secretflow_fl.ml.nn.core.torch.TorchModel` 调用 `torch.nn.CrossEntropyLoss` ，依此类推，我们可以调用 [torch.nn loss function](https://pytorch.org/docs/stable/nn.html#loss-functions) 中的任意损失函数。然而，当我们需要根据自己的任务自定义损失函数时，需要怎样做呢？本教程将回答这一问题。\n",
    "### 教程提醒\n",
    "注意，本自定义 loss function 教程主要关注输入形式为$(\\hat{y},y)$的损失函数，而不讨论超出此范围的自定义损失函数。\n",
    "具体到本教程，本教程将给出如何自定义实现\n",
    "$$\n",
    "Loss(\\hat{y},y) = 0.8*CEL(\\hat{y},y) + 0.2*MSE(\\hat{y},y)\n",
    "$$\n",
    "其中，$CEL$ 表示 [cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss) ，$MSE$ 表示[mean squared error](https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss)，对于其他的损失函数组合形式，您可以自行定义和组合。\n",
    "\n",
    "再度提醒，本教程只是作为教程示例，展示代码的实现，而不作为实际生产应用的模型训练指导。\n",
    "\n",
    "让我们开始吧！"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 基础教程\n",
    "为突出重点，简化教程，本教程将以 [使用Pytorch后端来进行联邦学习](https://www.secretflow.org.cn/docs/secretflow/latest/zh-Hans/tutorial/Federated_Learning_with_Pytorch_backend) 为基础，重点突出自定义损失函数的做法。所以，为了让代码能够顺利运行，让我们先把之前的代码复制过来。因此如果您对原教程非常熟悉，则不需要再阅读这部分代码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The version of SecretFlow: 1.4.0.dev20231225\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-01-10 09:34:35,376\tINFO worker.py:1538 -- Started a local Ray instance.\n"
     ]
    }
   ],
   "source": [
    "import secretflow as sf\n",
    "\n",
    "# Check the version of your SecretFlow\n",
    "print('The version of SecretFlow: {}'.format(sf.__version__))\n",
    "\n",
    "# In case you have a running secretflow runtime already.\n",
    "sf.shutdown()\n",
    "\n",
    "sf.init(['alice', 'bob', 'charlie'], address='local')\n",
    "alice, bob, charlie = sf.PYU('alice'), sf.PYU('bob'), sf.PYU('charlie')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-01-10 09:34:37.234123: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /content/conda-env/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64\n",
      "2024-01-10 09:34:38.012003: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /content/conda-env/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64\n",
      "2024-01-10 09:34:38.012080: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /content/conda-env/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64\n",
      "2024-01-10 09:34:38.012088: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
     ]
    }
   ],
   "source": [
    "from secretflow_fl.ml.nn.core.torch import (\n",
    "    metric_wrapper,\n",
    "    optim_wrapper,\n",
    "    BaseModule,\n",
    "    TorchModel,\n",
    ")\n",
    "from secretflow_fl.ml.nn import FLModel\n",
    "from torchmetrics import Accuracy, Precision\n",
    "from secretflow.security.aggregation import SecureAggregator\n",
    "from secretflow_fl.utils.simulation.datasets_fl import load_mnist\n",
    "from torch import nn, optim\n",
    "from torch.nn import functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvNet(BaseModule):\n",
    "    \"\"\"Small ConvNet for MNIST.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super(ConvNet, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 3, kernel_size=3)\n",
    "        self.fc_in_dim = 192\n",
    "        self.fc = nn.Linear(self.fc_in_dim, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(F.max_pool2d(self.conv1(x), 3))\n",
    "        x = x.view(-1, self.fc_in_dim)\n",
    "        x = self.fc(x)\n",
    "        return F.softmax(x, dim=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自定义损失函数\n",
    "如前所述，我们将自定义损失函数：\n",
    "$$\n",
    "Loss(\\hat{y},y) = 0.8*CEL(\\hat{y},y) + 0.2*MSE(\\hat{y},y)\n",
    "$$\n",
    "其中，$CEL$ 表示 [cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss)，$MSE$ 表示 [mean squared error](https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss)\n",
    "\n",
    "为实现这一个自定义损失函数，我们存在两种实现方式，一种是继承[torch.nn.module](https://github.com/pytorch/pytorch/tree/main/torch/nn/modules) 的类，另外一种直接定义函数。\n",
    "### 继承 torch.nn.module\n",
    "#### 继承介绍\n",
    "我们需要自行编写一个继承自 [torch.nn.module](https://github.com/pytorch/pytorch/tree/main/torch/nn/modules) 的类，而且至少实现两个基础的函数：`__init__` 和 `forward`，其中:\n",
    "- `__init__` 执行该类的初始化部分代码，本教程我们对基础损失函数 `CrossEntropyLoss` 和 `MSELoss` 进行了初始化的操作\n",
    "- `forward`  执行该类的调用时的运算代码，也就是自定义损失函数的运算逻辑，此处我们对上面所提及的自定义函数进行了实现"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 实现自定义类"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomLossFunction(nn.Module):\n",
    "    def __init__(self, *args, **kwargs) -> None:\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='sum')\n",
    "        self.mse_loss = nn.MSELoss()\n",
    "\n",
    "    def forward(self, input, target):\n",
    "        return 0.8 * self.cross_entropy_loss(input, target) + 0.2 * self.mse_loss(\n",
    "            input, target\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 直接自定义损失函数\n",
    "#### 自定义函数介绍\n",
    "我们也可以直接定义损失函数，对于同一实现，直接实现如下：\n",
    "#### 自定义函数实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def my_loss_function(input, target):\n",
    "    cross_entropy_loss = nn.CrossEntropyLoss(reduction='sum')\n",
    "    mse_loss = nn.MSELoss()\n",
    "    return 0.8 * cross_entropy_loss(input, target) + 0.2 * mse_loss(input, target)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 指定自定义损失函数\n",
    "### 继承 torch.nn.module\n",
    "当我们通过继承 torch.nn.module 实现自定义函数时，我们可以在下面的单元格里，通过\n",
    "``\n",
    "loss_fn = CustomLossFunction\n",
    "``\n",
    "指定我们自定义的损失函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# here we use the loss function we defined above\n",
    "loss_fn = CustomLossFunction\n",
    "\n",
    "optim_fn = optim_wrapper(optim.Adam, lr=1e-2)\n",
    "model_def = TorchModel(\n",
    "    model_fn=ConvNet,\n",
    "    loss_fn=loss_fn,\n",
    "    optim_fn=optim_fn,\n",
    "    metrics=[\n",
    "        metric_wrapper(Accuracy, task=\"multiclass\", num_classes=10, average='micro'),\n",
    "        metric_wrapper(Precision, task=\"multiclass\", num_classes=10, average='micro'),\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 直接自定义损失函数\n",
    "当我们通过直接自定义损失函数实现时，我们可以在下面的单元格里，通过\n",
    "``\n",
    "loss_fn = my_loss_function\n",
    "``\n",
    "指定我们自定义的损失函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# here we use the loss function we defined above\n",
    "loss_fn = my_loss_function\n",
    "\n",
    "optim_fn = optim_wrapper(optim.Adam, lr=1e-2)\n",
    "model_def = TorchModel(\n",
    "    model_fn=ConvNet,\n",
    "    loss_fn=loss_fn,\n",
    "    optim_fn=optim_fn,\n",
    "    metrics=[\n",
    "        metric_wrapper(Accuracy, task=\"multiclass\", num_classes=10, average='micro'),\n",
    "        metric_wrapper(Precision, task=\"multiclass\", num_classes=10, average='micro'),\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "通过本教程，我们将学会如何基于 PyTorch 在SecretFlow 中自定义实现输入形式为 $(\\hat{y},y)$ 的损失函数。"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "multi-task-learning-example-pytorch.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
